%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% SVM Decoding
%
% GE_Step1e_diff_clusterstats.m
%
% Statistical evaluation using permutation tests
%
% Modified from GE_Step1d_clusterstats.m
%   by David Sorensen
%   --Cluster analysis on the difference in classifier accuracy between pairs of conditions (e.g. words vs nonwords); every pair of categories is compared
%

% Start from a clean session: no open figures and an empty workspace.
close all;
clear all;

% Project root; per-subject folders (<subject><analysis_tag>/results/) live beneath it.
svm_dir = '/autofs/space/clive_001/users/adriana/GE_SVM';

% Put the project scripts, then the FusionLab toolbox (presumably the source
% of fl_permtestcluster), on the MATLAB path including all subfolders.
scripts_dir = [svm_dir '/scripts'];
addpath(genpath(scripts_dir));
addpath(genpath('/autofs/space/clive_001/users/adriana/UAG_SVM/scripts/fusionlab_toolbox'));

% Subjects GE_05 .. GE_24, zero-padded to two digits (1 x 20 cell of char).
SubjectNames = arrayfun(@(k) sprintf('GE_%02d', k), 5:24, 'UniformOutput', false);
n_subjects = numel(SubjectNames);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Specify the name of the analysis type:

analysis_tag = ''

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Specify the Regions of Interest ("ROIs", also called "labels":

roiset_array = {'R_postCG_2-rh', 'R_ITG_2-rh', 'R_ITG_1-rh', 'R_MTG_2-rh', 'R_MTG_1-rh', 'R_STG_1-rh', 'R_STG_2-rh', 'R_STG_3-rh', 'L_postCG_1-lh', 'L_ITG_2-lh', 'L_MTG_2-lh', 'L_STG_1-lh'};
n_roisets = length(roiset_array);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Specify the condition pairs.
%
% condition_tag selects the training scheme:
%   'lvn'   - separately train lexical and nonword neighbors; test seeds
%   'pos'   - separately train based on differential position; test seeds
%   'half'  - data divided in half randomly (keeping the proportion of words
%             to nonwords) for training
%   'third' - data divided in thirds randomly (keeping the proportion of
%             tokens changed at each position) for training
%   'sop'   - position, with training and testing reversed (train on hubs,
%             test on neighbors by position). NOT configured: no condition
%             codes are defined for it, so selecting it raises an error.
condition_tag = 'lvn';

% Condition pairs: every pair (i, j) of the 6 hub words with i < j, in the
% order (1,2),(1,3),...,(1,6),(2,3),...,(5,6) -> 15 pairs (table columns).
n_hubs = 6;
hub_pairs = nchoosek(1:n_hubs, 2);
hubA = hub_pairs(:, 1).';   % 1 x 15: first hub of each pair
hubB = hub_pairs(:, 2).';   % 1 x 15: second hub of each pair

% condition_codes(hubs, scales, offsets) -> numel(scales) x numel(hubs) cell
% of char codes; entry (r, k) is sprintf('%d', hubs(k)*scales(r) + offsets(r)).
% Row r belongs to category r.
condition_codes = @(hubs, scales, offsets) arrayfun(@(v) sprintf('%d', v), ...
    scales(:) .* hubs + offsets(:), 'UniformOutput', false);

switch condition_tag
    case 'lvn'
        categories = {'nonwords'; 'words'};
        train_scales  = [100; 10];            % '100'..'600' and '10'..'60'
        train_offsets = [0; 0];
    case 'pos'
        categories = {'initial'; 'vowel'; 'final'};
        train_scales  = [1000; 1000; 1000];   % e.g. '1001', '1002', '1003'
        train_offsets = [1; 2; 3];
    case 'half'
        categories = {'firsthalf'; 'secondhalf'};
        train_scales  = [100; 100];           % e.g. '101', '102'
        train_offsets = [1; 2];
    case 'third'
        categories = {'firstthird'; 'secondthird'; 'thirdthird'};
        train_scales  = [1000; 1000; 1000];   % e.g. '1010', '1020', '1030'
        train_offsets = [10; 20; 30];
    case 'sop'
        % Intended categories: {'initial_rev', 'vowel_rev', 'final_rev'}.
        error('GE_Step1e:notConfigured', ...
            'condition_tag ''sop'' has no train/test condition codes defined yet.');
    otherwise
        error('GE_Step1e:unknownConditionTag', ...
            'Unknown condition_tag ''%s''.', condition_tag);
end

% Training codes follow the per-scheme pattern above; test codes are the bare
% hub numbers ('1'..'6') for every category.
n_category_rows = numel(categories);
train_conditionA_array = condition_codes(hubA, train_scales, train_offsets);
train_conditionB_array = condition_codes(hubB, train_scales, train_offsets);
test_conditionA_array  = condition_codes(hubA, ones(n_category_rows, 1), zeros(n_category_rows, 1));
test_conditionB_array  = condition_codes(hubB, ones(n_category_rows, 1), zeros(n_category_rows, 1));

% Number of condition pairs = number of COLUMNS in the code tables
% (length() would return the largest dimension, not necessarily the columns).
n_condpairs = min(size(train_conditionA_array, 2), size(train_conditionB_array, 2));

% Every code table needs one row per category and at least n_condpairs
% columns; catch a mismatch here rather than as an index error while loading.
condition_tables = {train_conditionA_array, train_conditionB_array, ...
                    test_conditionA_array, test_conditionB_array};
table_rows = cellfun(@(t) size(t, 1), condition_tables);
table_cols = cellfun(@(t) size(t, 2), condition_tables);
if any(table_rows ~= numel(categories)) || any(table_cols < n_condpairs)
    error('GE_Step1e:tableShape', ...
        'Condition tables must be %d x >=%d; got rows %s, cols %s.', ...
        numel(categories), n_condpairs, mat2str(table_rows), mat2str(table_cols));
end

% All pairs of categories to compare; one row per comparison: [first, second].
cat_diffs = nchoosek(1:length(categories), 2);

% Subplot-layout constants (currently unused in this script).
n_hubwords = 6;
n_cols = n_hubwords - 1;
n_rows = n_hubwords - 1;

num_permutation = 1000;   % permutations per cluster test

% Parallel arrays: setting k uses alpha_array(k) with cluster_alpha_array(k).
%alpha_array         = [0.05, 0.1];
%cluster_alpha_array = [0.05, 0.1];
alpha_array         = 0.05;
cluster_alpha_array = 0.05;
n_alpha = min(length(alpha_array), length(cluster_alpha_array));

% Analysis window in msec: samples with window(1) <= Time <= window(end) are
% kept. timepoints = length(window) assumes one sample per msec in the result
% files -- TODO confirm against their Time vector.
%timepoints=1101;
window = 100:600;
timepoints = length(window);

%savefigs=1;
savefigs = 0;

% Per-subject difference time courses handed to the cluster test
% (row 1: first - second category, row 2: second - first).
zeroed_ave_data_cell = cell(2, n_subjects);

% Decoding accuracy for the ROI being processed:
% category x subject x condition pair x time sample.
accuracy_all = NaN(length(categories), n_subjects, n_condpairs, timepoints);

% Timestamp naming this run's output folders.
time_tag = datetime('now', 'Format', 'yyyyMMdd-HHmm');


% For each ROI: load every subject's decoding-accuracy time course. Then, for
% each (alpha, cluster alpha) setting and each pair of categories:
%   - run right-tailed cluster permutation tests on the per-subject accuracy
%     difference (both directions);
%   - plot each category's across-subject mean +/- SD;
%   - shade significant clusters;
%   - optionally save the figure and the test results.
% ROI is the outer loop so each ROI's result files are read once, not once
% per alpha setting.
n_diffs = size(cat_diffs, 1);

for i_roi = 1:n_roisets
    roiset = char(roiset_array(i_roi));

    %--- Load accuracy for every category x condition pair x subject ---
    for i_category = 1:length(categories)
        for i_condpair = 1:n_condpairs
            train_conditionA = char(train_conditionA_array(i_category, i_condpair));
            train_conditionB = char(train_conditionB_array(i_category, i_condpair));
            test_conditionA  = char(test_conditionA_array(i_category, i_condpair));
            test_conditionB  = char(test_conditionB_array(i_category, i_condpair));

            for i_subject = 1:n_subjects
                subject = char(SubjectNames(i_subject));
                subject_dir = char(strcat(subject, analysis_tag));
                input_dir = sprintf('%s/%s/results/', svm_dir, subject_dir);
                inputfile = sprintf('%s/%s_%s_train%sand%s_test%svs%s_Accuracy.mat', ...
                    input_dir, subject, roiset, train_conditionA, train_conditionB, ...
                    test_conditionA, test_conditionB);

                % Load into a struct. A bare load() writes the file's
                % variables into the script workspace, so an 'x' from one
                % file would still exist when the next file is read, and
                % names such as 'window' could be overwritten.
                S = load(inputfile);
                if isfield(S, 'x')
                    % Some result files store accuracy / time as x / y.
                    accuracy = S.x;
                    Time = S.y;
                else
                    accuracy = S.accuracy;
                    Time = S.Time;
                end

                in_window = Time >= window(1) & Time <= window(end);
                if nnz(in_window) ~= timepoints
                    error('GE_Step1e:windowSamples', ...
                        '%s: %d samples in [%g, %g], expected %d.', ...
                        inputfile, nnz(in_window), window(1), window(end), timepoints);
                end
                accuracy_all(i_category, i_subject, i_condpair, :) = accuracy(in_window);
            end % i_subject
        end % i_condpair
    end % i_category

    for i_alpha = 1:n_alpha   % try different cluster statistics parameters
        this_alpha = alpha_array(i_alpha);
        this_cluster_alpha = cluster_alpha_array(i_alpha);

        % Fresh for every (ROI, alpha), so the saved stat_ave never carries
        % entries left over from a previous ROI or alpha setting.
        stat_ave = cell(n_diffs, 2);

        for i_diff = 1:n_diffs
            cat1 = cat_diffs(i_diff, 1);
            cat2 = cat_diffs(i_diff, 2);

            % Per-subject difference time course, averaged over condition pairs.
            for i_subject = 1:n_subjects
                zeroed_ave_data_cell{1, i_subject} = squeeze(mean(accuracy_all(cat1, i_subject, :, :) - accuracy_all(cat2, i_subject, :, :), 3));
                zeroed_ave_data_cell{2, i_subject} = squeeze(mean(accuracy_all(cat2, i_subject, :, :) - accuracy_all(cat1, i_subject, :, :), 3));
            end % i_subject

            % Both tests are right-tailed. The second runs on the negated
            % difference, so it finds clusters where the second category
            % exceeds the first.
            stat_ave{i_diff, 1} = fl_permtestcluster(zeroed_ave_data_cell(1, :), 'tail', 'right', 'statistic', 'tstat', 'verbose', 0, 'numpermutation', num_permutation, 'alpha', this_alpha, 'clusteralpha', this_cluster_alpha);   % first category > second
            stat_ave{i_diff, 2} = fl_permtestcluster(zeroed_ave_data_cell(2, :), 'tail', 'right', 'statistic', 'tstat', 'verbose', 0, 'numpermutation', num_permutation, 'alpha', this_alpha, 'clusteralpha', this_cluster_alpha);   % second category > first

            % Per-subject accuracy averaged over condition pairs (n_subjects x timepoints).
            subj_acc1 = reshape(mean(accuracy_all(cat1, :, :, :), 3), n_subjects, timepoints);
            subj_acc2 = reshape(mean(accuracy_all(cat2, :, :, :), 3), n_subjects, timepoints);
            % Across-subject mean and SD (std weight 1: normalised by N, not N-1).
            accuracy_all_ave = [mean(subj_acc1, 1); mean(subj_acc2, 1)];
            accuracy_all_std = [std(subj_acc1, 1, 1); std(subj_acc2, 1, 1)];

            figure();
            shadedErrorBar(window, accuracy_all_ave(1, :), accuracy_all_std(1, :), 'lineprops', '-b', 'patchSaturation', 0.16);
            hold on;
            shadedErrorBar(window, accuracy_all_ave(2, :), accuracy_all_std(2, :), 'lineprops', '-r', 'patchSaturation', 0.16);
            l1 = plot(window, accuracy_all_ave(1, :), 'Color', 'b', 'DisplayName', categories{cat1});
            l2 = plot(window, accuracy_all_ave(2, :), 'Color', 'r', 'DisplayName', categories{cat2});

            ylabel('Decoding Accuracy (%)', 'fontsize', 10)
            xlabel('Time (msec)', 'fontsize', 10)
            x_min = -200;
            x_max = 1100;
            y_min = 30;
            y_max = 70;
            axis([x_min x_max y_min y_max]);
            xticks(-100:100:1000);
            yticks(y_min:5:y_max);
            yline(50);
            xline(0);

            titletext = [roiset ':' categories{cat1} 'vs' categories{cat2}];
            title(titletext, 'fontsize', 10, 'Interpreter', 'none');

            % Shade significant clusters. Column 1 (first > second) is dark
            % blue; column 2 (second > first) is dark red. FaceAlpha is set on
            % each area itself: alpha(0.2) acts on the transparency-capable
            % objects in the current axes, which would include the error-bar
            % patches.
            cluster_colors = {[0, 0, 0.5], [0.5, 0, 0]};
            for i_tail = 1:2
                for i_clust = 1:length(stat_ave{i_diff, i_tail}.clusters)
                    iic = cell2mat(stat_ave{i_diff, i_tail}.clusters(i_clust));
                    area([window(iic(1)) window(iic(end))], [100 100], ...
                        'FaceColor', cluster_colors{i_tail}, 'LineStyle', 'none', 'FaceAlpha', 0.2);
                end % i_clust
            end % i_tail

            hold off;
            legend([l1, l2]);

            if savefigs
                fig_save_stem = sprintf('ClusterStats_%s_clusteralpha_%s_alpha_%s', roiset, string(this_cluster_alpha), string(this_alpha));
                % time_tag is a datetime; convert it explicitly for sprintf.
                results_dir = sprintf('%s/GE_grand_average_results/%s-%s/%s', svm_dir, categories{cat1}, categories{cat2}, char(time_tag));
                [status, msg] = mkdir(results_dir);
                if ~status
                    error('GE_Step1e:mkdirFailed', 'Could not create %s: %s', results_dir, msg);
                end
                results_fig = sprintf('%s/%s.fig', results_dir, fig_save_stem);
                savefig(results_fig);
                results_png = sprintf('%s/%s.png', results_dir, fig_save_stem);
                print(gcf, results_png, '-dpng', '-r300');
                % NOTE(review): this name omits the alpha values, so with
                % several alpha settings each one overwrites the previous file.
                stat_filename = sprintf('%s/cluster_stats_%s', results_dir, roiset_array{i_roi});
                save(stat_filename, 'stat_ave');
            end % if savefigs
        end % i_diff
    end % i_alpha
end % i_roi
